package com.raos.websocket2.handler;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.*;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Custom {@link WebSocketHandler} that tracks connected sessions and echoes incoming JSON messages.
 *
 * <p>Each session is registered under the key {@code <sessionId>$<userId>}, where the user id is
 * read from the {@code ID=} parameter of the handshake URI.
 *
 * <p>NOTE(review): {@link WebSocketSession#sendMessage} is not safe for concurrent use on the same
 * session; if broadcasts can race with per-user sends, consider wrapping sessions in Spring's
 * {@code ConcurrentWebSocketSessionDecorator}.
 *
 * @author raos
 * @email 991207823@qq.com
 * @date 2021/10/16 9:35
 */
public class MyHandler implements WebSocketHandler {

    /** Separator between the session id and the user id in registry keys. */
    private static final String KEY_SEPARATOR = "$";

    /** Marker preceding the user id in the handshake URI. */
    private static final String ID_MARKER = "ID=";

    /**
     * Online users, keyed by {@code <sessionId>$<userId>}. Handler callbacks may run on different
     * container threads, so the registry must be a concurrent map (a plain HashMap can throw
     * ConcurrentModificationException during broadcast or corrupt its internal state).
     */
    private static final Map<String, WebSocketSession> users = new ConcurrentHashMap<>();

    /**
     * Registers a newly opened session under its user id and sends a greeting.
     *
     * <p>Sessions whose handshake URI carries no {@code ID=} value are closed with
     * {@link CloseStatus#BAD_DATA} because they could never be addressed.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        System.out.println("成功建立连接 ...");
        String userId = extractUserId(session.getUri() == null ? null : session.getUri().toString());
        System.out.println("用户ID=" + userId);
        if (userId != null) {
            String key = session.getId() + KEY_SEPARATOR + userId;
            users.put(key, session);
            session.sendMessage(new TextMessage("成功建立socket连接 ..."));
            System.out.println("当前存储key=" + key);
            System.out.println(session);
        } else {
            // Previously this path threw ArrayIndexOutOfBoundsException; reject explicitly instead.
            session.close(CloseStatus.BAD_DATA);
        }
        System.out.println("当前在线人数: " + users.size() + "\n");
    }

    /**
     * Extracts the user id following the first {@code ID=} in the given URI string.
     *
     * @param uri the handshake URI as a string, may be {@code null}
     * @return the id up to the next {@code &} or {@code #}, or {@code null} if absent or empty
     */
    private static String extractUserId(String uri) {
        if (uri == null) {
            return null;
        }
        int start = uri.indexOf(ID_MARKER);
        if (start < 0) {
            return null;
        }
        start += ID_MARKER.length();
        int end = start;
        // Stop at the next query parameter or fragment so "ID=5&x=1" yields "5", not "5&x=1".
        while (end < uri.length() && uri.charAt(end) != '&' && uri.charAt(end) != '#') {
            end++;
        }
        return end > start ? uri.substring(start, end) : null;
    }

    /**
     * Handles an incoming message of the form {@code {"id": ..., "message": ...}}. It replies to
     * the entry {@code <this sessionId>$<id>}, so only an id matching this session's own
     * registration gets a reply.
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (!(message instanceof TextMessage)) {
            // Binary and pong frames carry a ByteBuffer payload, not JSON text; nothing to parse.
            return;
        }
        try {
            JSONObject json = JSON.parseObject(((TextMessage) message).getPayload());
            if (json == null) {
                return;
            }
            System.out.println("当前id=" + json.get("id"));
            System.out.println("来自" + session.getAttributes().get("WEBSOCKET_USERID") + "的消息: "
                    + json.get("message"));
            sendMessageToUser(session.getId() + KEY_SEPARATOR + json.get("id"),
                    new TextMessage("服务器收到信息: " + json.get("message")));
        } catch (Exception e) {
            // Malformed payloads must not tear down the connection; report and continue.
            e.printStackTrace();
        }
    }

    /**
     * Sends a message to a specific registered session.
     *
     * @param clientId the registry key ({@code <sessionId>$<userId>}); {@code null} is treated as unknown
     * @param message the message to send
     * @return {@code true} if the session exists, is open, and the send succeeded
     */
    public boolean sendMessageToUser(String clientId, TextMessage message) {
        // ConcurrentHashMap rejects null keys; keep the old "unknown user" result for null.
        WebSocketSession session = clientId == null ? null : users.get(clientId);
        if (session == null) {
            return false;
        }
        System.out.println("sendMessage webSocketSession: " + session + "\n");
        if (!session.isOpen()) {
            return false;
        }
        try {
            session.sendMessage(message);
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
        return true;
    }

    /**
     * Broadcasts a message to every open registered session.
     *
     * @param message the message to send
     * @return {@code true} if no send failed with an {@link IOException}
     */
    public boolean sendMessageToAllUsers(TextMessage message) {
        boolean allSendSuccess = true;
        // Iteration over a ConcurrentHashMap view is weakly consistent and never throws CME.
        for (WebSocketSession session : users.values()) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                session.sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
                allSendSuccess = false;
            }
        }
        return allSendSuccess;
    }

    /** Deregisters the session after a transport error, then closes it if still open. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable throwable) throws Exception {
        System.out.println("连接出错...");
        throwable.printStackTrace();
        // Deregister first so a failing close() cannot leave a dead session in the registry.
        removeSession(session);
        System.out.println("当前在线人数: " + users.size() + "\n");
        if (session.isOpen()) {
            session.close();
        }
    }

    /**
     * Removes every registry entry belonging to the given session.
     *
     * <p>Matching is by session id rather than by rebuilding the key, because the user-id part of
     * the key comes from the handshake URI and is not otherwise recorded on the session.
     */
    private static void removeSession(WebSocketSession session) {
        String sessionId = session.getId();
        users.values().removeIf(s -> sessionId.equals(s.getId()));
    }

    /** Deregisters the session once the connection has closed. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        System.out.println("连接已关闭：" + status);
        removeSession(session);
        System.out.println("当前在线人数: " + users.size() + "\n");
    }

    /** Partial (fragmented) messages are not supported; the container delivers whole messages. */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

}
